{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn import * \n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>job</th>\n",
       "      <th>marital</th>\n",
       "      <th>education</th>\n",
       "      <th>default</th>\n",
       "      <th>balance</th>\n",
       "      <th>housing</th>\n",
       "      <th>loan</th>\n",
       "      <th>contact</th>\n",
       "      <th>day</th>\n",
       "      <th>month</th>\n",
       "      <th>duration</th>\n",
       "      <th>campaign</th>\n",
       "      <th>pdays</th>\n",
       "      <th>previous</th>\n",
       "      <th>poutcome</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>58</td>\n",
       "      <td>management</td>\n",
       "      <td>married</td>\n",
       "      <td>tertiary</td>\n",
       "      <td>no</td>\n",
       "      <td>2143</td>\n",
       "      <td>yes</td>\n",
       "      <td>no</td>\n",
       "      <td>unknown</td>\n",
       "      <td>5</td>\n",
       "      <td>may</td>\n",
       "      <td>261</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>unknown</td>\n",
       "      <td>no</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>44</td>\n",
       "      <td>technician</td>\n",
       "      <td>single</td>\n",
       "      <td>secondary</td>\n",
       "      <td>no</td>\n",
       "      <td>29</td>\n",
       "      <td>yes</td>\n",
       "      <td>no</td>\n",
       "      <td>unknown</td>\n",
       "      <td>5</td>\n",
       "      <td>may</td>\n",
       "      <td>151</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>unknown</td>\n",
       "      <td>no</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>33</td>\n",
       "      <td>entrepreneur</td>\n",
       "      <td>married</td>\n",
       "      <td>secondary</td>\n",
       "      <td>no</td>\n",
       "      <td>2</td>\n",
       "      <td>yes</td>\n",
       "      <td>yes</td>\n",
       "      <td>unknown</td>\n",
       "      <td>5</td>\n",
       "      <td>may</td>\n",
       "      <td>76</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>unknown</td>\n",
       "      <td>no</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>47</td>\n",
       "      <td>blue-collar</td>\n",
       "      <td>married</td>\n",
       "      <td>unknown</td>\n",
       "      <td>no</td>\n",
       "      <td>1506</td>\n",
       "      <td>yes</td>\n",
       "      <td>no</td>\n",
       "      <td>unknown</td>\n",
       "      <td>5</td>\n",
       "      <td>may</td>\n",
       "      <td>92</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>unknown</td>\n",
       "      <td>no</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>33</td>\n",
       "      <td>unknown</td>\n",
       "      <td>single</td>\n",
       "      <td>unknown</td>\n",
       "      <td>no</td>\n",
       "      <td>1</td>\n",
       "      <td>no</td>\n",
       "      <td>no</td>\n",
       "      <td>unknown</td>\n",
       "      <td>5</td>\n",
       "      <td>may</td>\n",
       "      <td>198</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>unknown</td>\n",
       "      <td>no</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age           job  marital  education default  balance housing loan  \\\n",
       "0   58    management  married   tertiary      no     2143     yes   no   \n",
       "1   44    technician   single  secondary      no       29     yes   no   \n",
       "2   33  entrepreneur  married  secondary      no        2     yes  yes   \n",
       "3   47   blue-collar  married    unknown      no     1506     yes   no   \n",
       "4   33       unknown   single    unknown      no        1      no   no   \n",
       "\n",
       "   contact  day month  duration  campaign  pdays  previous poutcome   y  \n",
       "0  unknown    5   may       261         1     -1         0  unknown  no  \n",
       "1  unknown    5   may       151         1     -1         0  unknown  no  \n",
       "2  unknown    5   may        76         1     -1         0  unknown  no  \n",
       "3  unknown    5   may        92         1     -1         0  unknown  no  \n",
       "4  unknown    5   may       198         1     -1         0  unknown  no  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(\"/data/bank-full.csv\", sep=\";\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 45211 entries, 0 to 45210\n",
      "Data columns (total 17 columns):\n",
      " #   Column     Non-Null Count  Dtype \n",
      "---  ------     --------------  ----- \n",
      " 0   age        45211 non-null  int64 \n",
      " 1   job        45211 non-null  object\n",
      " 2   marital    45211 non-null  object\n",
      " 3   education  45211 non-null  object\n",
      " 4   default    45211 non-null  object\n",
      " 5   balance    45211 non-null  int64 \n",
      " 6   housing    45211 non-null  object\n",
      " 7   loan       45211 non-null  object\n",
      " 8   contact    45211 non-null  object\n",
      " 9   day        45211 non-null  int64 \n",
      " 10  month      45211 non-null  object\n",
      " 11  duration   45211 non-null  int64 \n",
      " 12  campaign   45211 non-null  int64 \n",
      " 13  pdays      45211 non-null  int64 \n",
      " 14  previous   45211 non-null  int64 \n",
      " 15  poutcome   45211 non-null  object\n",
      " 16  y          45211 non-null  object\n",
      "dtypes: int64(7), object(10)\n",
      "memory usage: 5.9+ MB\n"
     ]
    }
   ],
   "source": [
    "df.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "no     39922\n",
       "yes     5289\n",
       "Name: y, dtype: int64"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.y.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "no     0.883015\n",
       "yes    0.116985\n",
       "Name: y, dtype: float64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.y.value_counts()/len(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>balance</th>\n",
       "      <th>day</th>\n",
       "      <th>duration</th>\n",
       "      <th>campaign</th>\n",
       "      <th>pdays</th>\n",
       "      <th>previous</th>\n",
       "      <th>job_blue-collar</th>\n",
       "      <th>job_entrepreneur</th>\n",
       "      <th>job_housemaid</th>\n",
       "      <th>...</th>\n",
       "      <th>month_jun</th>\n",
       "      <th>month_mar</th>\n",
       "      <th>month_may</th>\n",
       "      <th>month_nov</th>\n",
       "      <th>month_oct</th>\n",
       "      <th>month_sep</th>\n",
       "      <th>poutcome_other</th>\n",
       "      <th>poutcome_success</th>\n",
       "      <th>poutcome_unknown</th>\n",
       "      <th>y_yes</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>58</td>\n",
       "      <td>2143</td>\n",
       "      <td>5</td>\n",
       "      <td>261</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>44</td>\n",
       "      <td>29</td>\n",
       "      <td>5</td>\n",
       "      <td>151</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>33</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>76</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>47</td>\n",
       "      <td>1506</td>\n",
       "      <td>5</td>\n",
       "      <td>92</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>33</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>198</td>\n",
       "      <td>1</td>\n",
       "      <td>-1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 43 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  balance  day  duration  campaign  pdays  previous  job_blue-collar  \\\n",
       "0   58     2143    5       261         1     -1         0                0   \n",
       "1   44       29    5       151         1     -1         0                0   \n",
       "2   33        2    5        76         1     -1         0                0   \n",
       "3   47     1506    5        92         1     -1         0                1   \n",
       "4   33        1    5       198         1     -1         0                0   \n",
       "\n",
       "   job_entrepreneur  job_housemaid  ...  month_jun  month_mar  month_may  \\\n",
       "0                 0              0  ...          0          0          1   \n",
       "1                 0              0  ...          0          0          1   \n",
       "2                 1              0  ...          0          0          1   \n",
       "3                 0              0  ...          0          0          1   \n",
       "4                 0              0  ...          0          0          1   \n",
       "\n",
       "   month_nov  month_oct  month_sep  poutcome_other  poutcome_success  \\\n",
       "0          0          0          0               0                 0   \n",
       "1          0          0          0               0                 0   \n",
       "2          0          0          0               0                 0   \n",
       "3          0          0          0               0                 0   \n",
       "4          0          0          0               0                 0   \n",
       "\n",
       "   poutcome_unknown  y_yes  \n",
       "0                 1      0  \n",
       "1                 1      0  \n",
       "2                 1      0  \n",
       "3                 1      0  \n",
       "4                 1      0  \n",
       "\n",
       "[5 rows x 43 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_dummy = pd.get_dummies(df, drop_first=True)\n",
    "df_dummy.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "target = \"y_yes\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n",
       "                       max_depth=4, max_features=None, max_leaf_nodes=None,\n",
       "                       min_impurity_decrease=0.0, min_impurity_split=None,\n",
       "                       min_samples_leaf=100, min_samples_split=250,\n",
       "                       min_weight_fraction_leaf=0.0, presort='deprecated',\n",
       "                       random_state=None, splitter='best')"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = df_dummy[target]\n",
    "X = df_dummy.drop(columns=target)\n",
    "\n",
    "X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size = 0.3, random_state = 1)\n",
    "\n",
    "cls = tree.DecisionTreeClassifier(\n",
    "    max_depth=4, \n",
    "    min_samples_leaf= 100,\n",
    "    min_samples_split= 250\n",
    ")\n",
    "\n",
    "cls.fit(X_train, y_train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training accuracy:  0.9016652447309381\n",
      "test accuracy:  0.9009879091713359\n",
      "training precision:  0.6382508833922261\n",
      "test precision:  0.608786610878661\n",
      "training recall:  0.38657035848047083\n",
      "test recall:  0.37524177949709864\n"
     ]
    }
   ],
   "source": [
    "y_train_pred = cls.predict(X_train)\n",
    "y_test_pred = cls.predict(X_test)\n",
    "\n",
    "y_test_prob = cls.predict_proba(X_test)[:, 1]\n",
    "\n",
    "print(\"training accuracy: \", metrics.accuracy_score(y_train, y_train_pred))\n",
    "print(\"test accuracy: \", metrics.accuracy_score(y_test, y_test_pred))\n",
    "print(\"training precision: \", metrics.precision_score(y_train, y_train_pred))\n",
    "print(\"test precision: \", metrics.precision_score(y_test, y_test_pred))\n",
    "print(\"training recall: \", metrics.recall_score(y_train, y_train_pred))\n",
    "print(\"test recall: \", metrics.recall_score(y_test, y_test_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.09901209082866413"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1 - 0.9009879091713359"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[11639,   374],\n",
       "       [  969,   582]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = metrics.confusion_matrix(y_test, y_test_pred)\n",
    "m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(11639, 582, 374, 969)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TN = m[0][0]\n",
    "FP = m[0][1]\n",
    "FN = m[1][0]\n",
    "TP = m[1][1]\n",
    "TN, TP, FP, FN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.37524177949709864"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "recall = TP / (TP + FN)\n",
    "recall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.608786610878661"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision = TP / (TP + FP)\n",
    "precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training accuracy:  0.9016652447309381\n",
      "test accuracy:  0.9009879091713359\n",
      "training precision:  0.6382508833922261\n",
      "test precision:  0.608786610878661\n",
      "training recall:  0.38657035848047083\n",
      "test recall:  0.37524177949709864\n"
     ]
    }
   ],
   "source": [
    "y_test_pred = np.where(y_test_prob>0.5, 1, 0)\n",
    "\n",
    "print(\"training accuracy: \", metrics.accuracy_score(y_train, y_train_pred))\n",
    "print(\"test accuracy: \", metrics.accuracy_score(y_test, y_test_pred))\n",
    "print(\"training precision: \", metrics.precision_score(y_train, y_train_pred))\n",
    "print(\"test precision: \", metrics.precision_score(y_test, y_test_pred))\n",
    "print(\"training recall: \", metrics.recall_score(y_train, y_train_pred))\n",
    "print(\"test recall: \", metrics.recall_score(y_test, y_test_pred))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training accuracy:  0.9016652447309381\n",
      "test accuracy:  0.8561633736360955\n",
      "training precision:  0.6382508833922261\n",
      "test precision:  0.41928974979822436\n",
      "training recall:  0.38657035848047083\n",
      "test recall:  0.6698903932946486\n",
      "training f1:  0.4815061646117961\n",
      "test f1:  0.5157607346736163\n"
     ]
    }
   ],
   "source": [
    "y_test_pred = np.where(y_test_prob>0.1, 1, 0)\n",
    "\n",
    "print(\"training accuracy: \", metrics.accuracy_score(y_train, y_train_pred))\n",
    "print(\"test accuracy: \", metrics.accuracy_score(y_test, y_test_pred))\n",
    "print(\"training precision: \", metrics.precision_score(y_train, y_train_pred))\n",
    "print(\"test precision: \", metrics.precision_score(y_test, y_test_pred))\n",
    "print(\"training recall: \", metrics.recall_score(y_train, y_train_pred))\n",
    "print(\"test recall: \", metrics.recall_score(y_test, y_test_pred))\n",
    "print(\"training f1: \", metrics.f1_score(y_train, y_train_pred))\n",
    "print(\"test f1: \", metrics.f1_score(y_test, y_test_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "f1 = 2 * p * r / (p + r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.04592271, 0.18364611, 0.18471338, 0.19745223, 0.36373166,\n",
       "       0.41666667, 0.44907407, 0.51394422, 0.56441718, 0.57439446,\n",
       "       0.57843137, 0.66743119, 0.75502008, 0.79850746])"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(y_test_prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "14"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(np.unique(y_test_prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.tree import export_graphviz\n",
    "export_graphviz(cls, out_file = \"tree.dot\", feature_names = X.columns, filled=True)\n",
    "!dot -Tpng tree.dot -o tree.png"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3637316561844864"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "694/1908"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6674311926605505"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "291/436"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5743944636678201"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "332/578"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 5 folds for each of 120 candidates, totalling 600 fits\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.\n",
      "[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    1.4s\n",
      "[Parallel(n_jobs=8)]: Done 320 tasks      | elapsed:    5.2s\n",
      "[Parallel(n_jobs=8)]: Done 585 out of 600 | elapsed:    8.5s remaining:    0.2s\n",
      "[Parallel(n_jobs=8)]: Done 600 out of 600 | elapsed:    8.7s finished\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=5, error_score=nan,\n",
       "             estimator=DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None,\n",
       "                                              criterion='gini', max_depth=4,\n",
       "                                              max_features=None,\n",
       "                                              max_leaf_nodes=None,\n",
       "                                              min_impurity_decrease=0.0,\n",
       "                                              min_impurity_split=None,\n",
       "                                              min_samples_leaf=100,\n",
       "                                              min_samples_split=250,\n",
       "                                              min_weight_fraction_leaf=0.0,\n",
       "                                              presort='deprecated',\n",
       "                                              random_state=None,\n",
       "                                              splitter='best'),\n",
       "             iid='deprecated', n_jobs=8,\n",
       "             param_grid={'criterion': ['gini', 'entropy'],\n",
       "                         'max_depth': array([2, 3, 4, 5]),\n",
       "                         'min_samples_leaf': array([ 5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])},\n",
       "             pre_dispatch='2*n_jobs', refit=True, return_train_score=False,\n",
       "             scoring='accuracy', verbose=True)"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param_grid = {\n",
    "    \"max_depth\": np.arange(2, 6),\n",
    "    \"criterion\": [\"gini\", \"entropy\"],\n",
    "    \"min_samples_leaf\": np.arange(5, 20)\n",
    "}\n",
    "\n",
    "\n",
    "gsearch =model_selection.GridSearchCV(cls, param_grid=param_grid, scoring=\"accuracy\"\n",
    "                                      , cv = 5, verbose = True, n_jobs= 8)\n",
    "gsearch.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9023288147375739, 0.9006223415023051)"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gsearch.score(X_train, y_train), gsearch.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'criterion': 'gini', 'max_depth': 5, 'min_samples_leaf': 16}"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gsearch.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
